
Added kl_divergence for multivariate normals #1654

Merged: 3 commits, Oct 27, 2023

Conversation

@lumip (Contributor) commented Sep 29, 2023

KL implementation for MultivariateNormal

EDIT: just saw that there is already a pull request for this from a year ago (#1487). What's the hold up with that?
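For reference, the closed-form KL between p = N(mu_p, Sigma_p) and q = N(mu_q, Sigma_q) in D dimensions is 0.5 * (tr(Sigma_q^{-1} Sigma_p) + (mu_p - mu_q)^T Sigma_q^{-1} (mu_p - mu_q) - D + log det Sigma_q - log det Sigma_p). A minimal, unbatched sketch in terms of the Cholesky factors (illustrative only; the code in this PR additionally handles batch shapes and broadcasting):

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def kl_mvn(p_loc, p_scale_tril, q_loc, q_scale_tril):
    # Unbatched KL(p || q) with Sigma = L @ L.T for each distribution.
    D = p_loc.shape[-1]
    Lq_inv = solve_triangular(q_scale_tril, jnp.eye(D), lower=True)
    # tr(Sigma_q^{-1} Sigma_p) = ||Lq^{-1} @ Lp||_F^2
    tr = jnp.square(Lq_inv @ p_scale_tril).sum()
    # (mu_p - mu_q)^T Sigma_q^{-1} (mu_p - mu_q) = ||Lq^{-1} @ (mu_p - mu_q)||^2
    t1 = jnp.square(Lq_inv @ (p_loc - q_loc)).sum()
    # log det Sigma = 2 * sum(log diag(L)) for a Cholesky factor L
    log_det_ratio = 2 * (
        jnp.log(jnp.diagonal(p_scale_tril)).sum()
        - jnp.log(jnp.diagonal(q_scale_tril)).sum()
    )
    return 0.5 * (tr + t1 - D - log_det_ratio)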

@fehiepsi (Member) left a comment

Thank you @lumip! The author of #1487 did not respond, so I think we can use this PR instead. I have a few comments below to address the batching issues.

numpyro/distributions/kl.py: 3 review threads (outdated, resolved)
Commit: making the linter tests happy
@lumip (Contributor, Author) commented Oct 20, 2023

Thanks for the feedback, @fehiepsi. I have modified the code according to your suggestions and also added a test case where the batch dimensions of the two distributions are not identical (and fixed the linter complaints). If everything is good now, I'd still like to rebase onto the recent master and merge the commits.

@fehiepsi (Member) left a comment

LGTM, thanks @lumip! I have a few small comments on the assertions.


Lq_inv = solve_triangular(q_scale_tril, jnp.eye(D), lower=True)
q_half_log_det = jnp.log(jnp.diagonal(q.scale_tril, axis1=-2, axis2=-1)).sum(-1)
assert q_half_log_det.shape == q.batch_shape
@fehiepsi (Member):

nit: this assertion might not be true. In the MultivariateNormal implementation, we avoid unnecessary broadcasting (e.g. we can have a batch of means with a single scale_tril).

@lumip (Contributor, Author):

Right, I wasn't thinking of that. Removed the assertion and added tests for those cases.
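To illustrate the point about avoided broadcasting, here is a small sketch (shapes are illustrative; the exact internal shapes depend on how MultivariateNormal promotes its arguments). A batch of means sharing a single scale_tril yields a half-log-det whose shape broadcasts against the batch shape but need not equal it:

import jax.numpy as jnp
import numpyro.distributions as dist

# A batch of 4 means sharing one 3x3 scale_tril.
q = dist.MultivariateNormal(loc=jnp.zeros((4, 3)), scale_tril=jnp.eye(3))
print(q.batch_shape)       # (4,)
print(q.scale_tril.shape)  # not tiled out to (4, 3, 3)

q_half_log_det = jnp.log(jnp.diagonal(q.scale_tril, axis1=-2, axis2=-1)).sum(-1)
# q_half_log_det broadcasts against q.batch_shape but does not equal it,
# so the removed assertion would have failed for inputs like these.
print(q_half_log_det.shape)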

f" {p.event_shape} and {q.event_shape} for p and q, respectively."
)

if p.batch_shape != q.batch_shape:
    min_batch_ndim = min(len(p.batch_shape), len(q.batch_shape))
    if p.batch_shape[-min_batch_ndim:] != q.batch_shape[-min_batch_ndim:]:
@fehiepsi (Member):

How about only asserting that p.batch_shape and q.batch_shape can be broadcast together?

try:
    result_batch_shape = jnp.broadcast_shapes(p.batch_shape, q.batch_shape)
except ValueError:
    raise ...

@lumip (Contributor, Author):

done
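One possible shape of the accepted suggestion, as a hypothetical sketch (the helper name and error message are illustrative, not necessarily the merged code):

import jax.numpy as jnp

def _result_batch_shape(p, q):
    # Broadcast the two batch shapes, raising a descriptive error if they are incompatible.
    try:
        return jnp.broadcast_shapes(p.batch_shape, q.batch_shape)
    except ValueError as e:
        raise ValueError(
            "Batch shapes of p and q must be broadcastable, but got "
            f"{p.batch_shape} and {q.batch_shape}."
        ) from e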

assert jnp.ndim(q_mean) == 1
assert jnp.ndim(p_scale_tril) == 2
assert jnp.ndim(q_scale_tril) == 2
assert q.mean.shape == q.batch_shape + q.event_shape
@fehiepsi (Member):

those assertions are unnecessary.

@lumip (Contributor, Author):

removed


tr = _batch_trace_from_cholesky(Lq_inv @ p.scale_tril)
assert tr.shape == result_batch_shape
return .5 * (tr + t1 - D - log_det_ratio)
@fehiepsi (Member):

this assertion might not be true.

@lumip (Contributor, Author):

fixed
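The trace term relies on the identity tr(Sigma_q^{-1} Sigma_p) = ||Lq^{-1} Lp||_F^2, i.e. the sum of squared entries of Lq_inv @ p.scale_tril (presumably what _batch_trace_from_cholesky computes; that is an assumption, not confirmed by the excerpt above). A quick unbatched check:

import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

key_p, key_q = jax.random.split(jax.random.PRNGKey(0))
D = 3
# Build two valid lower-triangular scale factors via Cholesky of SPD matrices.
Ap = jax.random.normal(key_p, (D, D))
Aq = jax.random.normal(key_q, (D, D))
Lp = jnp.linalg.cholesky(Ap @ Ap.T + jnp.eye(D))
Lq = jnp.linalg.cholesky(Aq @ Aq.T + jnp.eye(D))

Lq_inv = solve_triangular(Lq, jnp.eye(D), lower=True)
tr_chol = jnp.square(Lq_inv @ Lp).sum()                        # ||Lq^{-1} Lp||_F^2
tr_direct = jnp.trace(jnp.linalg.solve(Lq @ Lq.T, Lp @ Lp.T))  # tr(Sigma_q^{-1} Sigma_p)
print(jnp.allclose(tr_chol, tr_direct, atol=1e-5))             # True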

p_mean_flat = jnp.reshape(p.mean, (-1, D))
p_scale_tril_flat = jnp.reshape(p.scale_tril, (-1, D, D))
t1 = jnp.square(Lq_inv @ (p.loc - q.loc)[..., jnp.newaxis]).sum((-2, -1))
assert t1.shape == result_batch_shape
@fehiepsi (Member):

this assertion might not be true

@lumip (Contributor, Author):

fixed
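For the quadratic term, the trailing jnp.newaxis is what makes the expression batch-friendly: each mean difference becomes a column vector, Lq^{-1} is applied along the event dimension for every batch element, and summing over the last two axes leaves one scalar per batch element. A small sketch with illustrative shapes:

import jax
import jax.numpy as jnp

D = 3
diff = jax.random.normal(jax.random.PRNGKey(2), (5, D))  # stands in for p.loc - q.loc
Lq_inv = jnp.eye(D)  # identity in place of a real Lq^{-1}, to keep the example small

# (5, D) -> (5, D, 1): each difference becomes a column vector, so the matmul
# applies Lq^{-1} along the event dimension for every batch element.
t1 = jnp.square(Lq_inv @ diff[..., jnp.newaxis]).sum((-2, -1))
print(t1.shape)  # (5,)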

p_half_log_det = jnp.log(jnp.diagonal(p.scale_tril, axis1=-2, axis2=-1)).sum(-1)
assert p_half_log_det.shape == p.batch_shape
log_det_ratio = 2 * (p_half_log_det - q_half_log_det)
@fehiepsi (Member):

this assertion might not be true

@lumip (Contributor, Author):

removed
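The half-log-determinant trick uses log det(L L^T) = 2 * sum(log diag(L)) for a lower-triangular L with positive diagonal, which is why only the diagonal of scale_tril is needed. A quick check:

import jax
import jax.numpy as jnp

D = 3
A = jax.random.normal(jax.random.PRNGKey(3), (D, D))
L = jnp.linalg.cholesky(A @ A.T + jnp.eye(D))  # a valid scale_tril

half_log_det = jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1)).sum(-1)
_, log_det = jnp.linalg.slogdet(L @ L.T)       # log det of the full covariance
print(jnp.allclose(2 * half_log_det, log_det, atol=1e-5))  # True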

((), ()),
((1,), (1,)),
((2, 3), (2, 3)),
((5, 2, 3), (2, 3)),
@fehiepsi (Member):

could you change this to (5, 1, 3) and (2, 3)?

@lumip (Contributor, Author):

done

((1,), (1,)),
((2, 3), (2, 3)),
((5, 2, 3), (2, 3)),
((2, 3), (5, 2, 3)),
@fehiepsi (Member):

nit: maybe ((1, 3), (5, 2, 3))?

@lumip (Contributor, Author):

done
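A hypothetical sketch of how such batch-shape pairs might be exercised in a test (the test name, shapes, and check are illustrative, not the PR's actual test code):

import jax
import jax.numpy as jnp
import pytest
import numpyro.distributions as dist
from numpyro.distributions.kl import kl_divergence

@pytest.mark.parametrize(
    "p_batch, q_batch",
    [((), ()), ((1,), (1,)), ((2, 3), (2, 3)), ((5, 1, 3), (2, 3)), ((1, 3), (5, 2, 3))],
)
def test_kl_mvn_batch_shapes(p_batch, q_batch):
    D = 4
    key_p, key_q = jax.random.split(jax.random.PRNGKey(0))
    p = dist.MultivariateNormal(
        loc=jax.random.normal(key_p, p_batch + (D,)), scale_tril=jnp.eye(D)
    )
    q = dist.MultivariateNormal(
        loc=jax.random.normal(key_q, q_batch + (D,)), scale_tril=2 * jnp.eye(D)
    )
    kl = kl_divergence(p, q)
    # The KL should come out with the broadcast of the two batch shapes.
    assert kl.shape == jnp.broadcast_shapes(p_batch, q_batch)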

@fehiepsi (Member) left a comment

Thanks, @lumip!

@fehiepsi merged commit eaa29a0 into pyro-ppl:master on Oct 27, 2023
4 checks passed
@fehiepsi mentioned this pull request on Oct 27, 2023